using Pkg
Pkg.activate(".")
Pkg.resolve()
Pkg.instantiate()

using Distributed
nprocs() == 1 && addprocs(8)
using DataFrames
using DataFramesMeta
using CSV
using StatsBase
using Statistics
using SplitApplyCombine
using Missings
using BenchmarkTools
using Distributions
using Optim
using JSON
using LaTeXTabulars
using LaTeXStrings
using ForwardDiff
using LinearAlgebra: diag
using LineSearches
using BlackBoxOptim
using Serialization
using Revise


################################################
#%% Check and create necessary directories
################################################
# Define the directories to check
# Output folders this script writes to (relative to the working directory).
directories = ["saved_objects", "snippets", "temp_saves"]

# Report folders that are already present; create the ones that are missing.
for dir in directories
    if isdir(dir)
        println("Directory already exists: $dir")
        continue
    end
    println("Creating directory: $dir")
    mkdir(dir)
end


################################################
#%% Load data
################################################
# Subject ids to exclude: first column of a header-less CSV file.
dropped = (CSV.File("dropped_subjects_section4paragraph2.csv", header=false) |> DataFrame)[!, 1]

# Raw choice data of both experiments, without the dropped subjects.
data1 = @subset(CSV.File("choices_exp1_inclZalloc.csv") |> DataFrame, :sid .∉ Ref(dropped))
data2 = @subset(CSV.File("choices_exp2_inclZalloc.csv") |> DataFrame, :sid .∉ Ref(dropped))

# Working copies used for estimation; the raw frames stay untouched.
X = deepcopy(data1)
X2 = deepcopy(data2)

# The z-allocation columns are read as strings with "NA" entries:
# replace "NA" by "0" and parse the column to integers.
for frame in (X, X2), zcol in (:self_z, :other_z)
    frame[!, zcol] = parse.(Int64, replace.(frame[!, zcol], "NA" => "0"))
end

# Constant `type` column on both data sets.
X[!, :type] .= 1
X2[!, :type] .= 1

# Rows with `dg == 1` only.
X_dg = @subset(X, :dg .== 1)
X2_dg = @subset(X2, :dg .== 1)
################################################
#%% Load functions
################################################
# Load the PreferenceModels module on the master process and alias it as `PM`.
include("PreferenceModels.jl")
import .PreferenceModels
PM = PreferenceModels


# Activate the project in the current directory on every process.
@everywhere begin
    using Pkg
    Pkg.activate(".")
end

# Load PreferenceModels on every process and alias it as `PM` there as well.
# NOTE(review): `@everywhere` also evaluates on the master process, so the file
# appears to be included a second time there and `PM` is bound twice — confirm
# that this is intended.
@everywhere include("PreferenceModels.jl")
@everywhere begin
    import .PreferenceModels as PM
end

@everywhere begin
    # Serialize `d` with `JSON.json` and write the text to the file at `path`
    # (opened in "w" mode). Returns what `write` returns.
    function save_JSON(d, path)
        emit = io -> write(io, JSON.json(d))
        return open(emit, path, "w")
    end

    # Pull every entry of `x` that lies strictly outside `[lb[i], ub[i]]` back
    # into the interval, in place, and return `x`.
    #
    # An entry below `lb[i]` becomes `lb[i] + margin`, an entry above `ub[i]`
    # becomes `ub[i] - margin`; entries already inside (including exactly on a
    # bound) are left untouched. `margin` is 0.01, capped at half the interval
    # width so that on intervals narrower than 0.02 the corrected value cannot
    # overshoot the opposite bound.
    #
    # Arguments: `x`, `lb`, `ub` are indexable collections sharing the indices
    # of `x`; `x` must be mutable and able to hold floating-point values.
    function assure_bouds(x, lb, ub)
        for i in eachindex(x)
            margin = min(0.01, (ub[i] - lb[i]) / 2)
            if x[i] < lb[i]
                x[i] = lb[i] + margin
            elseif x[i] > ub[i]
                x[i] = ub[i] - margin
            end
        end
        return x
    end

    # Number of distinct elements in `x`.
    function nunique(x)
        distinct = unique(x)
        return length(distinct)
    end

    # Parse the JSON file at `path` with `JSON.parsefile` and return the result.
    function read_JSON(path)
        return JSON.parsefile(path)
    end
    using DataFrames
    using DataFramesMeta
    using Optim
    using Distributed
    using StatsBase:mean
    using LineSearches
    using LinearAlgebra: diag
    using ForwardDiff
    using BlackBoxOptim
end


#%%
# Define the model specifications and the data sets on every process.
@everywhere begin
    # MixModel with one FSR component, and with three FSR components.
    m_fsr = PM.MixModel([PM.FSR()])
    m_3fsr = PM.MixModel([PM.FSR(), PM.FSR(), PM.FSR()])

    # Second MixModel argument: a list of parameter symbols, empty for `m1`.
    # NOTE(review): presumably the parameters shared across components —
    # confirm against PreferenceModels.jl.
    glob = Symbol[]
    m1 = PM.MixModel([PM.Decency()], glob)
    m3_glob_abg = PM.MixModel([PM.Decency(), PM.Decency(), PM.Decency()], [:α, :β, :γ])
    # `$X` / `$X2` interpolate the master's values into the expression, so
    # every process gets the cleaned data sets under the same names.
    X = $X
    X2 = $X2
end


#%%
"""
    opt_fun(X, m; bb=true, nm=true, glob_opt="all", paralell=true, bb_time=60,
            nm_time=60, cont_time=60, use_jump=false, pop_size=200)

Fit model `m` to data `X` with several independent multi-stage optimisation runs
and return the best one.

`nprocs()` runs are started through `pmap` (or `map` when `paralell == false`).
Every run chains up to three stages, each one starting from the result of the
previous stage:

1. if `bb`: `PM.opt_full_model` with `opt_f="SAMIN"` and time limit `bb_time`;
2. if `nm`: `PM.opt_full_model` with `opt_f="NM"` and time limit `nm_time`; the
   stage-1 model is kept instead when its objective is lower or when the stage-2
   objective is NaN;
3. always: `PM.opt_full_model` with `opt_f="LBFGS"`, or `PM.opt_full_model_Jump`
   with an `"Ipopt"` optimiser when `use_jump` is true (time limit `cont_time`).

After stages 1 and 2 the parameter vector is pushed back inside
`PM.lower(m)` / `PM.upper(m)` with `assure_bouds`. The objective compared
everywhere is `-PM.perf(X, model)`, so lower is better.

# Keywords
- `bb`, `nm`: switch stages 1 and 2 on or off.
- `paralell`: use `pmap` unless this is `false`.
- `bb_time`, `nm_time`, `cont_time`: time limits forwarded to the three stages.
- `use_jump`: select the JuMP/Ipopt variant of stage 3.
- `glob_opt`, `pop_size`: accepted for compatibility; not used by this function.

# Returns
A named tuple `(; opt_m, opt_x, opt_perf, res_vec)`:
- `res_vec`: one named tuple per run with the fields `perf_bb`, `x_bb`,
  `perf_nm`, `x_nm`, `perf_cont`, `x_cont`;
- `opt_x`: `x_cont` of the run with the lowest non-NaN `perf_cont`;
- `opt_m`: the result of `PM.set_params_from_opt!(opt_x, m)`;
- `opt_perf`: `PM.perf(X, opt_m)`.
"""
function opt_fun(X, m; bb=true, nm=true, glob_opt="all", paralell=true, bb_time=60,
                 nm_time=60, cont_time=60, use_jump=false, pop_size=200)
    mapf = paralell == false ? map : pmap

    # NOTE(review): the number of runs is nprocs(), which counts the master as
    # well as the workers — confirm this (rather than nworkers()) is intended.
    res_vec = mapf([(X, m) for i in 1:nprocs()]) do (X, m)
        lb = PM.lower(m)
        ub = PM.upper(m)

        # Stage 1: SAMIN search from the initial parameters.
        if bb
            opt_bb = PM.opt_full_model(X, deepcopy(m); opt_f="SAMIN", time_limit=bb_time)
            x_bb = PM.get_params_for_opt(opt_bb)
            x_bb = assure_bouds(x_bb, lb, ub)
            opt_bb = PM.set_params_from_opt!(x_bb, opt_bb)
            println("BB opt performance : $(PM.perf(X, opt_bb))")
        else
            opt_bb = deepcopy(m)
            x_bb = PM.get_params_for_opt(m)
        end

        # Stage 2: Nelder-Mead refinement, started from the stage-1 model.
        if nm
            opt_nm = PM.opt_full_model(X, deepcopy(opt_bb); opt_f="NM", time_limit=nm_time)
            x_nm = PM.get_params_for_opt(opt_nm)
            x_nm = assure_bouds(x_nm, lb, ub)
            opt_nm = PM.set_params_from_opt!(x_nm, opt_nm)
            bb_perf = -PM.perf(X, opt_bb)
            nm_perf = -PM.perf(X, opt_nm)
            # Keep the stage-1 model when the refinement made things worse.
            # A NaN objective compares false with `<`, so it is tested explicitly.
            if isnan(nm_perf) || bb_perf < nm_perf
                opt_nm = deepcopy(opt_bb)
                x_nm = deepcopy(x_bb)
            end
        else
            opt_nm = deepcopy(opt_bb)
            x_nm = PM.get_params_for_opt(opt_nm)
        end

        # Stage 3: continuous optimiser, started from the stage-2 model.
        if use_jump
            opt_f = PM.gen_opt_fun(;time_limit=cont_time, opt_f="Ipopt")
            opt_cont = PM.opt_full_model_Jump(X, deepcopy(opt_nm); optimizer_fun=opt_f)
        else
            opt_cont = PM.opt_full_model(X, deepcopy(opt_nm); opt_f="LBFGS", time_limit=cont_time)
        end
        x_cont = PM.get_params_for_opt(opt_cont)

        (;perf_bb=-PM.perf(X, opt_bb), x_bb=x_bb, perf_nm=-PM.perf(X, opt_nm), x_nm=x_nm, perf_cont=-PM.perf(X, opt_cont), x_cont=x_cont,)
    end

    cont_perfs = [res[:perf_cont] for res in res_vec]

    # `argmin` treats NaN as smaller than every number, so one diverged run
    # would otherwise be picked as the winner. Choose among the non-NaN runs
    # and fall back to plain `argmin` only when every run is NaN.
    usable = findall(!isnan, cont_perfs)
    min_idx = isempty(usable) ? argmin(cont_perfs) : usable[argmin(cont_perfs[usable])]

    opt_x = res_vec[min_idx][:x_cont]
    opt_m = PM.set_params_from_opt!(opt_x, m)
    opt_perf = PM.perf(X, opt_m)

    return (; opt_m, opt_x, opt_perf, res_vec)
end


# Build the table rows for one fitted model `opt` (one row per component).
# The first row carries `session`, the rounded performance `p` (column `ll`)
# and accuracy `accu`; the rows of further components leave these three
# columns as empty strings. With more than one component each row also gets
# the component's `share`. Returns a DataFrame.
function _session_table(opt, session, p, accu)
    param_names = PM.get_fields(opt.types[1])
    head = Dict{Any,Any}(:session=>session, :ll=>round(p,digits=1), :accu=>round(accu,digits=3))

    if opt.K == 1
        for (name, param) in zip(param_names, PM.get_params(opt))
            head[name] = round(param, digits=3)
        end
        return DataFrame(head)
    end

    rows = map(1:opt.K) do k
        row = k == 1 ? head : Dict{Any,Any}(:session=>"", :ll=>"", :accu=>"")
        for (name, param) in zip(param_names, PM.get_params(opt.types[k]))
            row[name] = round(param, digits=3)
        end
        row[:share] = round(opt.shares[k], digits=3)
        row
    end
    dfs = [DataFrame([NamedTuple{Tuple(keys(d))}(values(d))]) for d in rows]
    return vcat(dfs..., cols=:union)
end

"""
    save_fun(opt_1, opt_2, namn; X=X, X2=X2, save_params=true, only_params=false,
             p1=nothing, p2=nothing, accu1=nothing, accu2=nothing)

Write the results of two fitted models (session 1: `opt_1`, session 2: `opt_2`)
to disk and return them as one table.

# Side effects
- If `save_params`: the parameter vectors from `PM.get_params_for_opt` are
  written as JSON to `temp_saves/params_<namn>_1.json` and
  `temp_saves/params_<namn>_2.json`.
- A LaTeX tabular is written to `snippets/<namn>_perfs.tex`, with a mid rule
  between the session-1 rows and the session-2 rows.
- The `session` column is printed to standard output.

# Keywords
- `X`, `X2`: data used to compute performance and accuracy when those are not
  supplied; they default to the globals of the same name.
- `p1`, `p2`, `accu1`, `accu2`: precomputed values; `nothing` means compute
  them with `PM.perf` / `PM.accuracy`.
- `only_params`: when true the `ll` and `accu` columns are left out of the table.

# Returns
The combined DataFrame with every column converted to strings and missing
values shown as empty strings.
"""
function save_fun(opt_1, opt_2, namn; X=X, X2=X2, save_params = true, only_params=false, p1=nothing, p2=nothing, accu1=nothing, accu2=nothing)
    if save_params
        save_JSON(PM.get_params_for_opt(opt_1), "temp_saves/params_$(namn)_1.json")
        save_JSON(PM.get_params_for_opt(opt_2), "temp_saves/params_$(namn)_2.json")
    end

    p1 = isnothing(p1) ? PM.perf(X, opt_1) : p1
    accu1 = isnothing(accu1) ? PM.accuracy(X, opt_1) : accu1
    df1 = _session_table(opt_1, 1, p1, accu1)

    p2 = isnothing(p2) ? PM.perf(X2, opt_2) : p2
    accu2 = isnothing(accu2) ? PM.accuracy(X2, opt_2) : accu2
    df2 = _session_table(opt_2, 2, p2, accu2)

    df = vcat(df1, df2)

    # Column layout: descriptive columns, then (for several components) the
    # share, then the model parameters. The `|` in the alignment string
    # separates the descriptive columns from the rest.
    meta_cols = only_params ? [:session] : [:session, :ll, :accu]
    param_names_title = filter(x -> x ∉ [:session, :ll, :accu, :share], Symbol.(names(df)))
    has_shares = opt_1.K > 1
    lead_cols = has_shares ? vcat(meta_cols, :share) : meta_cols
    df = df[:, vcat(lead_cols, param_names_title)]
    # `^` repetition also covers the case of zero parameter columns.
    align = "c"^length(meta_cols) * "|" * "c"^(length(param_names_title) + has_shares)

    title_namn = replace(names(df), "α" => L"\alpha", "β" => L"\beta", "γ" => L"\gamma" ,"δ" => L"\delta", "λ" => L"\lambda", "κ"=>L"\kappa", "π"=>L"\pi", "μ"=>L"\mu")

    # Turn every column into strings; blank out cells that were missing.
    for col in names(df)
        df[!,col] .= string.(df[!,col])
        df[df[!,col] .== "missing", col] .= ""
    end

    title_namn = replace(title_namn, "accu" => "Accu", "ll" => "Ll", "shares" => "Shares", "share" => "Share", "session"=>"Session")
    print(df[!, "session"])
    latex_tabular("snippets/$(namn)_perfs.tex",Tabular(align),
              [Rule(:top),
              title_namn,
               Rule(:mid),
               Matrix(df[1:opt_1.K,:]),
               Rule(:mid),
               Matrix(df[opt_1.K+1:end,:]),
               Rule(:bottom)])
    return df
end

# function run_save(m, namn;X=X, X2=X2, bb=true, nm=true, glob_opt="all", paralell=true, bb_time=3600, nm_time=600, cont_time=600, use_jump=false)
# TODO: restore the time limits of the commented-out signature above (bb_time=3600, nm_time=600, cont_time=600)
"""
    run_save(m, namn; X=X, X2=X2, bb=true, nm=true, glob_opt="all", paralell=true,
             bb_time=60, nm_time=60, cont_time=60, use_jump=false)

Fit a fresh copy of model `m` to `X` and to `X2` with `opt_fun`, hand both
optimised models to `save_fun` under the name `namn`, and return the two
`opt_fun` results as the named tuple `(; res_1, res_2)`.

All keywords other than `X` and `X2` are forwarded unchanged to `opt_fun`.
"""
function run_save(m, namn;X=X, X2=X2, bb=true, nm=true, glob_opt="all", paralell=true, bb_time=60, nm_time=60, cont_time=60, use_jump=false)
    # Same optimiser settings for both data sets.
    fit_opts = (; bb, glob_opt, nm, paralell, bb_time, nm_time, cont_time, use_jump)
    res_1, res_2 = map(data -> opt_fun(data, deepcopy(m); fit_opts...), (X, X2))
    save_fun(res_1[:opt_m], res_2[:opt_m], namn; X=X, X2=X2)
    return (; res_1, res_2)
end

"""
    load_resave(m, load_namn; X=X, X2=X2, save_namn=nothing, only_params=false)

Read the two parameter files stored under `load_namn` in `temp_saves`, write
them into copies of model `m`, and regenerate the result table with `save_fun`
(without saving the parameters again). The table is written under `save_namn`,
or under `load_namn` when `save_namn` is `nothing`. Returns what `save_fun`
returns.
"""
function load_resave(m, load_namn; X=X, X2=X2, save_namn=nothing, only_params=false)
    path_1 = "temp_saves/params_$(load_namn)_1.json"
    path_2 = "temp_saves/params_$(load_namn)_2.json"
    opt_1 = PM.set_params_from_opt!(read_JSON(path_1), deepcopy(m))
    opt_2 = PM.set_params_from_opt!(read_JSON(path_2), deepcopy(m))
    target = isnothing(save_namn) ? load_namn : save_namn
    return save_fun(opt_1, opt_2, target; X=X, X2=X2, save_params=false, only_params=only_params)
end

"""
    load(m, load_namn, idx)

Read the parameter file stored under `load_namn` with index `idx` in
`temp_saves` and return the result of writing those parameters into a copy of
model `m`.
"""
function load(m, load_namn, idx)
    stored = read_JSON("temp_saves/params_$(load_namn)_$(idx).json")
    return PM.set_params_from_opt!(stored, deepcopy(m))
end


#%%
# Fit the three-component FSR model on the `dg == 1` rows of both data sets,
# then serialize the raw optimisation results.
println("Running 3fsr_dg")
res_3fsr_dg = run_save(m_3fsr, "3fsr_dg"; X=X_dg, X2=X2_dg)
serialize("saved_objects/res_3fsr_dg", res_3fsr_dg)

# Same for the three-component Decency model built with [:α, :β, :γ].
println("Running m3_glob_abg_dg")
res_m3_glob_abg_dg = run_save(m3_glob_abg, "m3_glob_abg_dg"; X=X_dg, X2=X2_dg)
serialize("saved_objects/res_m3_glob_abg_dg", res_m3_glob_abg_dg)
